{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Improve retrieval performance by Fine-tuning embedding model\n",
        "\n",
        "Another way to improve retriever performance is to fine-tune the embedding model itself. Fine-tuning the embedding model can help in learning better representations for the documents and queries in the dataset. This can be particularly useful when the dataset is very different from the pre-trained data used to train the embedding model."
      ],
      "metadata": {
        "id": "rYMbEXANHZ0B"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "6T7bwebVquFE",
        "outputId": "55bea6d1-631f-409e-9b7b-cb441d26102a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 12.0.1 which is incompatible.\n",
            "datasets 2.20.0 requires pyarrow>=15.0.0, but you have pyarrow 12.0.1 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0m"
          ]
        }
      ],
      "source": [
        "%pip install llama-index-llms-openai llama-index-embeddings-openai llama-index-finetuning llama-index-readers-file scikit-learn llama-index-embeddings-huggingface llama-index-vector-stores-lancedb pyarrow==12.0.1 -qq"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# For eval utils\n",
        "!git clone https://github.com/lancedb/ragged.git\n",
        "!cd ragged && pip install .\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6RRNyCDJDEcQ",
        "outputId": "bbcb0689-e82f-4593-f53c-77c3443a929d"
      },
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'ragged'...\n",
            "remote: Enumerating objects: 160, done.\u001b[K\n",
            "remote: Counting objects: 100% (160/160), done.\u001b[K\n",
            "remote: Compressing objects: 100% (103/103), done.\u001b[K\n",
            "remote: Total 160 (delta 70), reused 125 (delta 41), pack-reused 0\u001b[K\n",
            "Receiving objects: 100% (160/160), 38.15 KiB | 9.54 MiB/s, done.\n",
            "Resolving deltas: 100% (70/70), done.\n",
            "Processing /content/ragged\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting datasets (from ragged==0.1.dev0)\n",
            "  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m547.8/547.8 kB\u001b[0m \u001b[31m13.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: lancedb in /usr/local/lib/python3.10/dist-packages (from ragged==0.1.dev0) (0.9.0)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from ragged==0.1.dev0) (2.0.3)\n",
            "Collecting streamlit (from ragged==0.1.dev0)\n",
            "  Downloading streamlit-1.36.0-py2.py3-none-any.whl (8.6 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m54.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tantivy in /usr/local/lib/python3.10/dist-packages (from ragged==0.1.dev0) (0.22.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (3.15.4)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (1.25.2)\n",
            "Collecting pyarrow>=15.0.0 (from datasets->ragged==0.1.dev0)\n",
            "  Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 MB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (0.6)\n",
            "Collecting dill<0.3.9,>=0.3.0 (from datasets->ragged==0.1.dev0)\n",
            "  Downloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m20.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting requests>=2.32.2 (from datasets->ragged==0.1.dev0)\n",
            "  Downloading requests-2.32.3-py3-none-any.whl (64 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.9/64.9 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (4.66.4)\n",
            "Collecting xxhash (from datasets->ragged==0.1.dev0)\n",
            "  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m29.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting multiprocess (from datasets->ragged==0.1.dev0)\n",
            "  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: fsspec[http]<=2024.5.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (2023.6.0)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (3.9.5)\n",
            "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (0.23.4)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (24.1)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets->ragged==0.1.dev0) (6.0.1)\n",
            "Requirement already satisfied: deprecation in /usr/local/lib/python3.10/dist-packages (from lancedb->ragged==0.1.dev0) (2.1.0)\n",
            "Requirement already satisfied: pylance==0.13.0 in /usr/local/lib/python3.10/dist-packages (from lancedb->ragged==0.1.dev0) (0.13.0)\n",
            "Requirement already satisfied: ratelimiter~=1.0 in /usr/local/lib/python3.10/dist-packages (from lancedb->ragged==0.1.dev0) (1.2.0.post0)\n",
            "Requirement already satisfied: retry>=0.9.2 in /usr/local/lib/python3.10/dist-packages (from lancedb->ragged==0.1.dev0) (0.9.2)\n",
            "Requirement already satisfied: pydantic>=1.10 in /usr/local/lib/python3.10/dist-packages (from lancedb->ragged==0.1.dev0) (2.8.0)\n",
            "Requirement already satisfied: attrs>=21.3.0 in /usr/local/lib/python3.10/dist-packages (from lancedb->ragged==0.1.dev0) (23.2.0)\n",
            "Requirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from lancedb->ragged==0.1.dev0) (5.3.3)\n",
            "Requirement already satisfied: overrides>=0.7 in /usr/local/lib/python3.10/dist-packages (from lancedb->ragged==0.1.dev0) (7.7.0)\n",
            "Collecting pyarrow>=15.0.0 (from datasets->ragged==0.1.dev0)\n",
            "  Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.3/38.3 MB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->ragged==0.1.dev0) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->ragged==0.1.dev0) (2023.4)\n",
            "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->ragged==0.1.dev0) (2024.1)\n",
            "Requirement already satisfied: altair<6,>=4.0 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (4.2.2)\n",
            "Requirement already satisfied: blinker<2,>=1.0.0 in /usr/lib/python3/dist-packages (from streamlit->ragged==0.1.dev0) (1.4)\n",
            "Requirement already satisfied: click<9,>=7.0 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (8.1.7)\n",
            "Requirement already satisfied: pillow<11,>=7.1.0 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (9.4.0)\n",
            "Requirement already satisfied: protobuf<6,>=3.20 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (3.20.3)\n",
            "Requirement already satisfied: rich<14,>=10.14.0 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (13.7.1)\n",
            "Requirement already satisfied: tenacity<9,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (8.3.0)\n",
            "Requirement already satisfied: toml<2,>=0.10.1 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (0.10.2)\n",
            "Requirement already satisfied: typing-extensions<5,>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (4.12.2)\n",
            "Collecting gitpython!=3.1.19,<4,>=3.0.7 (from streamlit->ragged==0.1.dev0)\n",
            "  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.3/207.3 kB\u001b[0m \u001b[31m22.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting pydeck<1,>=0.8.0b4 (from streamlit->ragged==0.1.dev0)\n",
            "  Downloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.9/6.9 MB\u001b[0m \u001b[31m63.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tornado<7,>=6.0.3 in /usr/local/lib/python3.10/dist-packages (from streamlit->ragged==0.1.dev0) (6.3.3)\n",
            "Collecting watchdog<5,>=2.1.5 (from streamlit->ragged==0.1.dev0)\n",
            "  Downloading watchdog-4.0.1-py3-none-manylinux2014_x86_64.whl (83 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.0/83.0 kB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: entrypoints in /usr/local/lib/python3.10/dist-packages (from altair<6,>=4.0->streamlit->ragged==0.1.dev0) (0.4)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from altair<6,>=4.0->streamlit->ragged==0.1.dev0) (3.1.4)\n",
            "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.10/dist-packages (from altair<6,>=4.0->streamlit->ragged==0.1.dev0) (4.19.2)\n",
            "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from altair<6,>=4.0->streamlit->ragged==0.1.dev0) (0.12.1)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->ragged==0.1.dev0) (1.3.1)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->ragged==0.1.dev0) (1.4.1)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->ragged==0.1.dev0) (6.0.5)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->ragged==0.1.dev0) (1.9.4)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->ragged==0.1.dev0) (4.0.3)\n",
            "Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.19,<4,>=3.0.7->streamlit->ragged==0.1.dev0)\n",
            "  Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb->ragged==0.1.dev0) (0.7.0)\n",
            "Requirement already satisfied: pydantic-core==2.20.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb->ragged==0.1.dev0) (2.20.0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->ragged==0.1.dev0) (1.16.0)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->ragged==0.1.dev0) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->ragged==0.1.dev0) (3.7)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->ragged==0.1.dev0) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->ragged==0.1.dev0) (2024.6.2)\n",
            "Requirement already satisfied: decorator>=3.4.2 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb->ragged==0.1.dev0) (4.4.2)\n",
            "Requirement already satisfied: py<2.0.0,>=1.4.26 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb->ragged==0.1.dev0) (1.11.0)\n",
            "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich<14,>=10.14.0->streamlit->ragged==0.1.dev0) (3.0.0)\n",
            "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich<14,>=10.14.0->streamlit->ragged==0.1.dev0) (2.16.1)\n",
            "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.19,<4,>=3.0.7->streamlit->ragged==0.1.dev0)\n",
            "  Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->altair<6,>=4.0->streamlit->ragged==0.1.dev0) (2.1.5)\n",
            "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit->ragged==0.1.dev0) (2023.12.1)\n",
            "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit->ragged==0.1.dev0) (0.35.1)\n",
            "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit->ragged==0.1.dev0) (0.18.1)\n",
            "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich<14,>=10.14.0->streamlit->ragged==0.1.dev0) (0.1.2)\n",
            "Building wheels for collected packages: ragged\n",
            "  Building wheel for ragged (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for ragged: filename=ragged-0.1.dev0-py3-none-any.whl size=24662 sha256=d086741b289188a92153223fdb65db69f9297a523c7874746fd1669f7d3f9c07\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-q327t6y_/wheels/aa/3f/b0/d70e6f86074491db9b0bc7431c11f0138f2ed2359151509cf7\n",
            "Successfully built ragged\n",
            "Installing collected packages: xxhash, watchdog, smmap, requests, pyarrow, dill, pydeck, multiprocess, gitdb, gitpython, datasets, streamlit, ragged\n",
            "  Attempting uninstall: requests\n",
            "    Found existing installation: requests 2.31.0\n",
            "    Uninstalling requests-2.31.0:\n",
            "      Successfully uninstalled requests-2.31.0\n",
            "  Attempting uninstall: pyarrow\n",
            "    Found existing installation: pyarrow 12.0.1\n",
            "    Uninstalling pyarrow-12.0.1:\n",
            "      Successfully uninstalled pyarrow-12.0.1\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 15.0.0 which is incompatible.\n",
            "google-colab 1.0.0 requires requests==2.31.0, but you have requests 2.32.3 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed datasets-2.20.0 dill-0.3.8 gitdb-4.0.11 gitpython-3.1.43 multiprocess-0.70.16 pyarrow-15.0.0 pydeck-0.9.1 ragged-0.1.dev0 requests-2.32.3 smmap-5.0.1 streamlit-1.36.0 watchdog-4.0.1 xxhash-3.4.1\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## The dataset\n",
        "The dataset we'll use is a synthetic QA dataset generated from LLama2 review paper. The paper was divided into chunks, with each chunk being a unique context. An LLM was prompted to ask questions relevant to the context for testing a retriever.\n",
        "The exact code and other utility functions for this can be found in [this](https://github.com/lancedb/ragged) repo\n"
      ],
      "metadata": {
        "id": "B_2S_b0c3pdp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!wget https://raw.githubusercontent.com/AyushExel/assets/main/data_qa.csv"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4QFDh3jD3d1X",
        "outputId": "642f53c8-a084-4c34-db6a-bfee35abbd28"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2024-07-09 20:37:46--  https://raw.githubusercontent.com/AyushExel/assets/main/data_qa.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 680439 (664K) [text/plain]\n",
            "Saving to: ‘data_qa.csv’\n",
            "\n",
            "data_qa.csv         100%[===================>] 664.49K  --.-KB/s    in 0.006s  \n",
            "\n",
            "2024-07-09 20:37:47 (100 MB/s) - ‘data_qa.csv’ saved [680439/680439]\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "data = pd.read_csv(\"data_qa.csv\")"
      ],
      "metadata": {
        "id": "AIF2zczc3kwW"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Pre-processing\n",
        "Now we need to parse the context(corpus) of the dataset as llama-index text nodes.  "
      ],
      "metadata": {
        "id": "_xV40VSy3twE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from pathlib import Path\n",
        "from llama_index.core.node_parser import SentenceSplitter\n",
        "from llama_index.readers.file import PagedCSVReader\n",
        "\n",
        "def load_corpus(file, verbose=False):\n",
        "    if verbose:\n",
        "        print(f\"Loading files {file}...\")\n",
        "\n",
        "    loader = PagedCSVReader(encoding=\"utf-8\")\n",
        "    docs = loader.load_data(file=Path(file))\n",
        "\n",
        "    if verbose:\n",
        "        print(f\"Loaded {len(docs)} docs\")\n",
        "\n",
        "    parser = SentenceSplitter()\n",
        "    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n",
        "\n",
        "    if verbose:\n",
        "        print(f\"Parsed {len(nodes)} nodes\")\n",
        "\n",
        "    return nodes"
      ],
      "metadata": {
        "id": "mzDZYUX4qxBC"
      },
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "df = pd.read_csv(\"data_qa.csv\", index_col=0)"
      ],
      "metadata": {
        "id": "eoLOdNO-4HbV"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "os.environ[\"OPENAI_API_KEY\"] = \"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\""
      ],
      "metadata": {
        "id": "EqsFZ5KYqzvg"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Split into train and validation sets. We'll use the original df for val as that has different queries generated via a different prompt.\n"
      ],
      "metadata": {
        "id": "zrwa35x96FLZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# Randomly shuffle df.\n",
        "#df = df.sample(frac=1, random_state=42)\n",
        "\n",
        "train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)\n",
        "\n",
        "train_df.to_csv(\"train_data_qa.csv\", index=False)\n",
        "val_df.to_csv(\"val_data_qa.csv\", index=False)"
      ],
      "metadata": {
        "id": "diHhY9Ipq9Uw"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_nodes = load_corpus(\"train_data_qa.csv\", verbose=True)\n",
        "val_nodes = load_corpus(\"val_data_qa.csv\", verbose=True)"
      ],
      "metadata": {
        "id": "C7PKGtXPq_Fc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 188,
          "referenced_widgets": [
            "3c85bfeaccc84a47844c770fa1fb2511",
            "7461a200b0634607ac479708e3cba537",
            "156a3c94ba094fbf86e70681c69ca31a",
            "4261378f06ed48cc8cef251cd2c096ab",
            "59aeeeae529a440fab4c231501fce4f6",
            "1308faaa9fa944b2b17e96e8cd9a9445",
            "83c0d87febcc4dfaa95b4d3e2005a416",
            "f503b1be4e2e42c8bf4460eea2f1bb07",
            "8228bf5a569844d584003446649731a6",
            "2255fb2d83734ef88843ffe47116da84",
            "9f30b1969bf24b86b8deaa41ea7231f6",
            "f55c2f3c448741819e618b44bc0b1976",
            "b0bad294bb6443388b77f854c4f77569",
            "812696b2a65c4ca281da45f286ab95cf",
            "9a026ebe3c8b416e9c5c4d7dd05bba66",
            "af4dfd45973d466cb5c78d002c723cd6",
            "211ee4b118154b0a94cbc686fdf90c55",
            "3929b1c14657468792c74c6610598af5",
            "222f778312d745aebc6f1d33c651dca8",
            "d65d92433f304e389e0ce8aa7baf7155",
            "a81c89c5c5a64ad0938d8d1e9789838c",
            "95a2e33e8d244d24813b386e60301a2a"
          ]
        },
        "outputId": "bcb428bd-5d02-444c-e456-22260402faa8"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loading files train_data_qa.csv...\n",
            "Loaded 176 docs\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Parsing nodes:   0%|          | 0/176 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "3c85bfeaccc84a47844c770fa1fb2511"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Parsed 221 nodes\n",
            "Loading files val_data_qa.csv...\n",
            "Loaded 44 docs\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Parsing nodes:   0%|          | 0/44 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "f55c2f3c448741819e618b44bc0b1976"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Parsed 59 nodes\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Generate the query from context from training\n"
      ],
      "metadata": {
        "id": "zMUSQPmkIu9N"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from llama_index.finetuning import generate_qa_embedding_pairs\n",
        "from llama_index.core.evaluation import EmbeddingQAFinetuneDataset"
      ],
      "metadata": {
        "id": "oPBGWH2or8_T"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from llama_index.llms.openai import OpenAI\n",
        "\n",
        "\n",
        "train_dataset = generate_qa_embedding_pairs(\n",
        "    llm=OpenAI(model=\"gpt-3.5-turbo\"), nodes=train_nodes, verbose=False\n",
        ")\n",
        "val_dataset = generate_qa_embedding_pairs(\n",
        "    llm=OpenAI(model=\"gpt-3.5-turbo\"), nodes=val_nodes, verbose=False\n",
        ")\n",
        "\n",
        "train_dataset.save_json(\"train_dataset.json\")\n",
        "val_dataset.save_json(\"val_dataset.json\")"
      ],
      "metadata": {
        "id": "VKLRRVzBr_dc",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3aff6b29-0e87-4862-d82c-668a3652a711"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 221/221 [05:29<00:00,  1.49s/it]\n",
            "221it [00:00, ?it/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Load again\n",
        "train_dataset = EmbeddingQAFinetuneDataset.from_json(\"train_dataset.json\")\n",
        "\n",
        "val_dataset = EmbeddingQAFinetuneDataset.from_json(\"val_dataset.json\")"
      ],
      "metadata": {
        "id": "cAFS_uZ-sThk"
      },
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Fine-tune the embedding model"
      ],
      "metadata": {
        "id": "GeOJCE51I2ay"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from llama_index.finetuning import SentenceTransformersFinetuneEngine\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "finetune_engine = SentenceTransformersFinetuneEngine(\n",
        "    train_dataset,\n",
        "    model_id=\"BAAI/bge-small-en-v1.5\",\n",
        "    model_output_path=\"tuned_model\",\n",
        "    val_dataset=val_dataset,\n",
        "    device=device\n",
        ")"
      ],
      "metadata": {
        "id": "9qDt2sH8sXYd",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 532,
          "referenced_widgets": [
            "e4b65385be4a40c8be59bd5d4fd33617",
            "492f25dcc6c642e18a12f57ef8b1c4bb",
            "715d4554f7de446dac06bbba077d2943",
            "af3c1a4d76854bf3ad2a1342c487490c",
            "fea3179057d64835909e614259584b24",
            "42a9a1d7a10e410bbce7336426a89eaf",
            "5ca598736e8d499f98d663a4e8468236",
            "0191245bc5564818bd0998684937e3f9",
            "005d5811d7e74a1eaf96454670e284bf",
            "ad27603f7b574c34b88b62fd5b13c370",
            "5d9b950ddac14d909b22f6a44d69dd7f",
            "61ffdbde1d3c41c985500c7dc0d0bb74",
            "eda0f3fcfd4c438b9e93cd7d76deaa31",
            "114e03e8b55146dfadbce1bfa4bb47f0",
            "6d21b28b0f3e40f29dc848d7cdaf9e65",
            "e058f4ef1ba4486295730eb1366fa807",
            "8fa97d81b8a845cdba358590a424e4f2",
            "12b397147f524979b992b23923dfa27b",
            "de7125ac21b24c44b808fe9bc0e94480",
            "45f6f78e1710479e8d92818142d15acf",
            "f54a557a93164d70a5e8b0f928fc597d",
            "2d7d5750f065464fb2f555f19ad28e2f",
            "d17df34105d541b4915bc7e6e9e1c730",
            "308cc174e3bb4505b088f30c313141be",
            "ecd3f44ad2c1406e9de08bbcd0eeca95",
            "31c353c1aa3f4be291819cc60235d26b",
            "c552bd86dcbd4aedb8480eed8b881dce",
            "04a79a8f0bf145fc994e1781b4e671a8",
            "9a05dd4db9bb4fef9ec4a8bce5293447",
            "31a32a223ccc432e87402e65d02d7262",
            "39d4188d85db4dfb8c2e596c92889e69",
            "7ef9a7e7af62407a8e5e54591365a5a7",
            "1f841a50b9e7435db9d8bf112b2139f5",
            "8c91c4fbea4c4d30a65598ec246fb839",
            "b23cc832cce44d01800a2d88cf2a90d0",
            "4621e32fc06d4c6f9742538b9426f27f",
            "88844faaadc249e38531f8fccf729bbe",
            "a31b27a1ff6149229c820c48ce89a1c6",
            "a57b26d8a8db4809832d27a80b78980e",
            "24b6864a33f943f2a5eb2e83dbf2418b",
            "f9361eabe75f49f9924dc91d4777a394",
            "16e7616dbf104684b917c28d8f127b4e",
            "79db201f9f214b49bf656d91edae63bc",
            "ce5e12cef02646a0bc6b83ce6ba9b254",
            "98499261c8274a9ab98922d177095294",
            "be1e04b1cc5b41a39447caafc6f4d0ea",
            "3201efc82fea4f55ba73badf51fee83c",
            "d0423a17fd464ae89f53aca72e881dab",
            "d7734c3c92b34653948b0534c6f4d555",
            "b5be64c76a1e40e5bd4c538dabc3c824",
            "ebb8217cd27d4f13aaad28ec2c58341d",
            "d05fb76e27364cb8bbe2bf2d264e244f",
            "5833232265cf412cab10b3509d0e46f9",
            "503ba67bb01041bf869c1e6a690a5fad",
            "54c6c1b5e9a94df5b23192b067c49f0a",
            "99bad626bca445c5a932592afb35abed",
            "cd439dfd1f3e4310a92c82cf5561a9ce",
            "fd7e509508e54fb69a59a58f07125030",
            "7660d0dbf36547b9838d0b462d3fa572",
            "a7510910eeb244b7bd6835809209b219",
            "4ec7b8d718f1417d8809b83700e41cc1",
            "106e605bbe394affa5f48e97f3aebcf1",
            "c1f2ac8ae31844a6a9d44ba345694595",
            "3544ca20a36e4551b29ec3b9075000e4",
            "f532dc0483de46be9da3feee5e7a6259",
            "8f05db60c37c485c9e77b2a3e014af11",
            "97497a14eaee45dbad617ac9ff9fff86",
            "2483cadd8a1f4fa5a55ac4f76c64c0b8",
            "c0b9a14e45c8473a852f0c1697e2797a",
            "b00d139e72e148c0895084f115d79bdc",
            "c625566ea3744d39b51b6647ae19edcc",
            "e35e506796b443e1a7b4a9154c62b14b",
            "f05a245a98dd4ab1a2125c222e6be9b8",
            "4651e3d6f7c249b39cac7e308dc80f13",
            "7ec8ecc6bdf1495ab167af44a2bcd654",
            "5e95c7e6831342a1acbd9e71ac43e57c",
            "4a7315c5405a4eab99910e3ab3eb01bf",
            "029195730530468dbbff0a21f0c366fe",
            "8af1379845aa42ff9dbcd6d27f783ab7",
            "3cb7c9bf226940ebbdcb4a9673c64a10",
            "a0da1cd7d7304a0b8357199d3d03b740",
            "49633c3119bb4f218d473dd860e04ee8",
            "f6bd25ca0d7a45c9bdaaad13d63865e5",
            "e45cbe865e964272a5cc10f78291b5c2",
            "7c727ddb56754832a74604b850c0e366",
            "707a5d42da404f578e8369009d436524",
            "ffd5638ae9244e9ca560b48d51875eec",
            "5d515357946d4f9d842381542fc407e5",
            "f71748dfd59846ffa6da01552ea7c37a",
            "9712eac12a8442d8a0bf0448410b1af3",
            "34c92004ec7b40beb2fa33f71618c608",
            "4bdd7d1976ee4cb3a4faba87a5368c68",
            "fc19c541426f4cf399b647131743a01b",
            "4c0f8b4fa80948b88b34e613beb47bff",
            "83f08debe4794039bd8080daa4b37d9a",
            "34677a7e86034a51915002eb7b9140ed",
            "b8848722ce174207b66c6b083d240062",
            "2008efeda4ad4245add43c324759480c",
            "530c758442c64faea91ff0e9a152b6c6",
            "190d4c074c2a4e5ca538001fcbef3e27",
            "2c542fb465cf4d7c90364d953063a563",
            "711c47520e2a45f69dc99018a0c0d436",
            "93e93576bc5049e6b3694299c2791d97",
            "69a59f0b20ce4281922b91f0155235b5",
            "59d362fddedb41db91ae20d82b6d7098",
            "8aeae7b1c9784bd1b3de79cd2cdab7a5",
            "2609a0e30677499785d9aa2e227ae8cb",
            "18affcb7a8394160bd3390f351f196cc",
            "e7e810edfdbf490584f427460dcd1922",
            "b384ae9262d145cb8da620ba1e6a30ac",
            "85f7243f06354a1fb14e49be3a683ff2",
            "728c4b65c062447fb10c13cf242edab6",
            "9cc727b8d4f547aeb44578c48f16deb3",
            "aa1673989de249fd8fca50daecc30b56",
            "59ceed5dbff6497da5c757235f2b1f56",
            "73fe958190144b3ebb056cc149f7056c",
            "9bf605e0314240eab7ac3bf8e362216e",
            "de80f3e14fd94584a15fdc99519a3623",
            "1cbcf1a6ab8d442a8e64a32080b2d852",
            "9f21d9bdabf74963b166fb09baa11cad",
            "02223aa9266d4b288f3648820beba55f"
          ]
        },
        "outputId": "adf4fd51-ad9e-4bd0-91f4-f187ef35ff8b"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e4b65385be4a40c8be59bd5d4fd33617"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "61ffdbde1d3c41c985500c7dc0d0bb74"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "README.md:   0%|          | 0.00/94.8k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d17df34105d541b4915bc7e6e9e1c730"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "8c91c4fbea4c4d30a65598ec246fb839"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "98499261c8274a9ab98922d177095294"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/133M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "99bad626bca445c5a932592afb35abed"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "97497a14eaee45dbad617ac9ff9fff86"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "029195730530468dbbff0a21f0c366fe"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "f71748dfd59846ffa6da01552ea7c37a"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "190d4c074c2a4e5ca538001fcbef3e27"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "85f7243f06354a1fb14e49be3a683ff2"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Run fine-tuning; the progress bars below report epoch and iteration progress.\n",
        "finetune_engine.finetune()\n"
      ],
      "metadata": {
        "id": "mkUNUehKsatO",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 113,
          "referenced_widgets": [
            "845dc253ec8f4973866800694d346748",
            "85fc16048366474da0578a1625920ea6",
            "d823de79a4794df99c2f92b6930f53cd",
            "5c50db266bab42c39c227902c6e104ce",
            "fa0cbfd581504cd6bf9683094e7eb744",
            "42a204225a604d63a11cac636b61d208",
            "569e0c2b605a4b6184a23e79a5f687c4",
            "dacfd12cbad044109fe193a99f0972a3",
            "ec640d05e2e747d09946f4b10f8b5bce",
            "3f3d2416ef2b46999447dcbd5542bb7c",
            "dfad9784f50f4d18bd1755e7b5114297",
            "0678d496424a46deb77d1a026083306b",
            "7ad6ce012ca24d70a3f7b8850d5fb034",
            "b8e6000ee14440348d31a2d7f46604eb",
            "f3fea079347640b2b67e3027a8cd8162",
            "30a21f7fb0d4408088e96c0ab2d40c09",
            "5395bcc56d2d4466b392781441f132a7",
            "e619cc5a34bc4e27afac06c5285fc1b4",
            "1b25e30b577c418d8e2e0d80cea4481f",
            "4802eecc59494767aa042ed2a2296f0c",
            "1e4cd904189940f38f51c013dfa50622",
            "858461487c9e4bf6bf8bc8abab575994",
            "c4965d3a5609482d94651dc29d435375",
            "1ed65863aef74be8aad450d779bb5ff2",
            "a28f0bdef19c4a4590d43827875ef9d7",
            "e9d6b5958d7143b99db3dd41289d5038",
            "55b91ccaa7fd49dab1668a636d0b686c",
            "a79590322ad74dd386698ddd2bb34bc6",
            "dcd0e411aa234758990d332869b471ac",
            "06f9435d6b5648b097e4001745e0216f",
            "353a0f45713e4041a5a2dded1fd20c3f",
            "0f9df48e885b4469a03780a199765e2a",
            "672aba50226d41a59761a02580ea4ffd"
          ]
        },
        "outputId": "ee3226a5-213f-45d5-d794-8bbb0f8b1c4a"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Epoch:   0%|          | 0/2 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "845dc253ec8f4973866800694d346748"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Iteration:   0%|          | 0/45 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "0678d496424a46deb77d1a026083306b"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Iteration:   0%|          | 0/45 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "c4965d3a5609482d94651dc29d435375"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Fetch the fine-tuned embedding model from the engine. Note: the hit-rate cells\n",
        "# below do not use `embed_model`; they reload the model from the \"tuned_model/\" path.\n",
        "embed_model = finetune_engine.get_finetuned_model()\n"
      ],
      "metadata": {
        "id": "VaSa6IMksbvg"
      },
      "execution_count": 20,
      "outputs": []
    },
    {
      "cell_type": "markdown",
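      "source": [
        "## Evaluate on hit rate\n",
        "\n",
        "Hit rate is the fraction of evaluation queries whose relevant context appears among the top-k retrieved results (k = 5 by default below). We compare the base `BAAI/bge-small-en-v1.5` model with the fine-tuned model across vector, full-text (FTS), reranked and hybrid search."
      ],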
      "metadata": {
        "id": "8OsyCxqGI6_y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from ragged.dataset import CSVDataset, SquadDataset\n",
        "from ragged.rag import llamaIndexRAG\n",
        "from ragged.metrics.retriever.hit_rate import HitRate\n",
        "from ragged.search_utils import QueryType\n",
        "\n",
        "\n",
        "def evaluate_vector(\n",
        "    dataset,\n",
        "    embed_model_name_or_path,\n",
        "    top_k=5,\n",
        "):\n",
        "    \"\"\"Evaluate retrieval hit rate using vector search only.\n",
        "\n",
        "    Args:\n",
        "        dataset: Path to the CSV file to load with ``CSVDataset``.\n",
        "        embed_model_name_or_path: Hugging Face model name or local model directory,\n",
        "            forwarded to ``HitRate`` as the embedding model name.\n",
        "        top_k: Retrieval cut-off passed to ``HitRate.evaluate``.\n",
        "\n",
        "    Returns:\n",
        "        The object returned by ``HitRate.evaluate``. It is also printed so the\n",
        "        result stays visible when the return value is assigned to a variable.\n",
        "    \"\"\"\n",
        "    csv_dataset = CSVDataset(dataset)\n",
        "    hit_rate = HitRate(csv_dataset, embed_model_kwarg={\"name\": embed_model_name_or_path})\n",
        "    result = hit_rate.evaluate(top_k, query_type=QueryType.VECTOR)\n",
        "    print(result)\n",
        "    return result\n",
        "\n",
        "\n",
        "def evaluate_all(\n",
        "    dataset,\n",
        "    embed_model_name_or_path,\n",
        "    reranker,\n",
        "    top_k=5,\n",
        "):\n",
        "    \"\"\"Evaluate retrieval hit rate for every query type (``QueryType.ALL``).\n",
        "\n",
        "    Args:\n",
        "        dataset: Path to the CSV file to load with ``CSVDataset``.\n",
        "        embed_model_name_or_path: Hugging Face model name or local model directory,\n",
        "            forwarded to ``HitRate`` as the embedding model name.\n",
        "        reranker: LanceDB reranker forwarded to ``HitRate`` (presumably applied\n",
        "            to the rerank_* and hybrid query types).\n",
        "        top_k: Retrieval cut-off passed to ``HitRate.evaluate``.\n",
        "\n",
        "    Returns:\n",
        "        The object returned by ``HitRate.evaluate``. It is also printed, and is\n",
        "        returned so callers that assign the result get a value instead of None.\n",
        "    \"\"\"\n",
        "    csv_dataset = CSVDataset(dataset)\n",
        "    hit_rate = HitRate(\n",
        "        csv_dataset,\n",
        "        embed_model_kwarg={\"name\": embed_model_name_or_path},\n",
        "        reranker=reranker,\n",
        "    )\n",
        "    result = hit_rate.evaluate(top_k, query_type=QueryType.ALL)\n",
        "    print(result)\n",
        "    return result\n"
      ],
      "metadata": {
        "id": "AAyIXEXwse8F"
      },
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from getpass import getpass\n",
        "\n",
        "from lancedb.rerankers import CohereReranker\n",
        "\n",
        "# Never hardcode API keys in a notebook: read the Cohere key from the environment,\n",
        "# and prompt for it (without echoing) only when it is not already set.\n",
        "if not os.environ.get(\"COHERE_API_KEY\"):\n",
        "    os.environ[\"COHERE_API_KEY\"] = getpass(\"Cohere API key: \")\n",
        "\n",
        "cohere_reranker = CohereReranker(api_key=os.environ[\"COHERE_API_KEY\"])\n",
        "\n",
        "# Baseline: the original (not fine-tuned) BAAI/bge-small-en-v1.5 embedding model.\n",
        "hit_rate_bge_cohere = evaluate_all(\"data_qa.csv\", \"BAAI/bge-small-en-v1.5\", cohere_reranker)\n"
      ],
      "metadata": {
        "id": "7_lm0QsI0CMG",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e950fac5-0c8d-4cf1-9eaf-ef57dfb02b6b"
      },
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Adding 110 documents to LanceDB, in 1 batches of size 110\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Adding 110 documents to LanceDB, in 1 batches of size 110\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Adding batch to LanceDB: 100%|██████████| 110/110 [00:00<00:00, 165663.71it/s]\n",
            "INFO:lancedb:Adding batch 0 to LanceDB\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Adding batch 0 to LanceDB\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:created table with length 110\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "created table with length 110\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: vector\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: vector\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [00:10<00:00, 20.61it/s]\n",
            "INFO:lancedb:Hit rate for vector: 0.6409090909090909\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for vector: 0.6409090909090909\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: fts\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: fts\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [00:00<00:00, 361.50it/s]\n",
            "INFO:lancedb:Hit rate for fts: 0.5954545454545455\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for fts: 0.5954545454545455\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: rerank_vector\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: rerank_vector\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [01:32<00:00,  2.38it/s]\n",
            "INFO:lancedb:Hit rate for rerank_vector: 0.6772727272727272\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for rerank_vector: 0.6772727272727272\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: rerank_fts\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: rerank_fts\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [01:23<00:00,  2.63it/s]\n",
            "INFO:lancedb:Hit rate for rerank_fts: 0.6727272727272727\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for rerank_fts: 0.6727272727272727\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: hybrid\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: hybrid\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [01:28<00:00,  2.47it/s]\n",
            "INFO:lancedb:Hit rate for hybrid: 0.759090909090909\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for hybrid: 0.759090909090909\n",
            "vector=0.6409090909090909 fts=0.5954545454545455 rerank_vector=0.6772727272727272 rerank_fts=0.6727272727272727 hybrid=0.759090909090909\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Same evaluation using the model in tuned_model/ (assumed to be the output\n",
        "# directory of the fine-tuning run above), for comparison with the baseline.\n",
        "hit_rate_tuned_cohere = evaluate_all(\"data_qa.csv\", \"tuned_model/\", cohere_reranker)\n"
      ],
      "metadata": {
        "id": "s-axsDpN1PRw",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4bfcce8c-5c6e-4411-dcfd-fef8f264c05f"
      },
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Adding 110 documents to LanceDB, in 1 batches of size 110\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Adding 110 documents to LanceDB, in 1 batches of size 110\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Adding batch to LanceDB: 100%|██████████| 110/110 [00:00<00:00, 91234.61it/s]\n",
            "INFO:lancedb:Adding batch 0 to LanceDB\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Adding batch 0 to LanceDB\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:created table with length 110\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "created table with length 110\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: vector\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: vector\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [00:09<00:00, 22.17it/s]\n",
            "INFO:lancedb:Hit rate for vector: 0.6727272727272727\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for vector: 0.6727272727272727\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: fts\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: fts\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [00:00<00:00, 285.43it/s]\n",
            "INFO:lancedb:Hit rate for fts: 0.5954545454545455\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for fts: 0.5954545454545455\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: rerank_vector\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: rerank_vector\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [01:29<00:00,  2.45it/s]\n",
            "INFO:lancedb:Hit rate for rerank_vector: 0.7545454545454545\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for rerank_vector: 0.7545454545454545\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: rerank_fts\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: rerank_fts\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [01:22<00:00,  2.66it/s]\n",
            "INFO:lancedb:Hit rate for rerank_fts: 0.6727272727272727\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for rerank_fts: 0.6727272727272727\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:lancedb:Evaluating query type: hybrid\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating query type: hybrid\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 220/220 [01:28<00:00,  2.48it/s]\n",
            "INFO:lancedb:Hit rate for hybrid: 0.7681818181818182\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Hit rate for hybrid: 0.7681818181818182\n",
            "vector=0.6727272727272727 fts=0.5954545454545455 rerank_vector=0.7545454545454545 rerank_fts=0.6727272727272727 hybrid=0.7681818181818182\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "FEJvpkC3Nyns"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}